{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v1CUZ0dkOo_F"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "qmkj-80IHxnd"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_xnMOsbqHz61"
      },
      "source": [
        "# Pix2Pix"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ds4o1h4WHz9U"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/generative/pix2pix\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/generative/pix2pix.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/pix2pix.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/generative/pix2pix.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ITZuApL56Mny"
      },
      "source": [
        "This notebook demonstrates image-to-image translation using conditional GANs, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique, we can colorize black-and-white photos, convert Google Maps to Google Earth, etc. Here, we convert building facade labels to photos of real buildings.\n",
        "\n",
        "In this example, we use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep the example short, we use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n",
        "\n",
        "Each epoch takes around 15 seconds on a single V100 GPU.\n",
        "\n",
        "Below is the output generated after training the model for 200 epochs.\n",
        "\n",
        "![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)\n",
        "![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e1_Y75QXJS6h"
      },
      "source": [
        "## Import TensorFlow and other libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YfIk2es3hJEd"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import os\n",
        "import time\n",
        "\n",
        "from matplotlib import pyplot as plt\n",
        "from IPython import display"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wifwThPoEj7e"
      },
      "outputs": [],
      "source": [
        "!pip install -U tensorboard"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iYn4MdZnKCey"
      },
      "source": [
        "## Load the dataset\n",
        "\n",
        "You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004), we apply random jittering and mirroring to the training dataset.\n",
        "\n",
        "* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.\n",
        "* In random mirroring, the image is randomly flipped horizontally, i.e., left to right."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kn-k8kTXuAlv"
      },
      "outputs": [],
      "source": [
        "_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'\n",
        "\n",
        "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n",
        "                                      origin=_URL,\n",
        "                                      extract=True)\n",
        "\n",
        "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2CbTEt448b4R"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = 400\n",
        "BATCH_SIZE = 1\n",
        "IMG_WIDTH = 256\n",
        "IMG_HEIGHT = 256"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aO9ZAGH5K3SY"
      },
      "outputs": [],
      "source": [
        "def load(image_file):\n",
        "  image = tf.io.read_file(image_file)\n",
        "  image = tf.image.decode_jpeg(image)\n",
        "\n",
        "  w = tf.shape(image)[1]\n",
        "\n",
        "  w = w // 2\n",
        "  real_image = image[:, :w, :]\n",
        "  input_image = image[:, w:, :]\n",
        "\n",
        "  input_image = tf.cast(input_image, tf.float32)\n",
        "  real_image = tf.cast(real_image, tf.float32)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4OLHMpsQ5aOv"
      },
      "outputs": [],
      "source": [
        "inp, re = load(PATH+'train/100.jpg')\n",
        "# scale pixel values to [0, 1] so matplotlib can display the float image\n",
        "plt.figure()\n",
        "plt.imshow(inp/255.0)\n",
        "plt.figure()\n",
        "plt.imshow(re/255.0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rwwYQpu9FzDu"
      },
      "outputs": [],
      "source": [
        "def resize(input_image, real_image, height, width):\n",
        "  input_image = tf.image.resize(input_image, [height, width],\n",
        "                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "  real_image = tf.image.resize(real_image, [height, width],\n",
        "                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Yn3IwqhiIszt"
      },
      "outputs": [],
      "source": [
        "def random_crop(input_image, real_image):\n",
        "  stacked_image = tf.stack([input_image, real_image], axis=0)\n",
        "  cropped_image = tf.image.random_crop(\n",
        "      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n",
        "\n",
        "  return cropped_image[0], cropped_image[1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "muhR2cgbLKWW"
      },
      "outputs": [],
      "source": [
        "# normalizing the images to [-1, 1]\n",
        "\n",
        "def normalize(input_image, real_image):\n",
        "  input_image = (input_image / 127.5) - 1\n",
        "  real_image = (real_image / 127.5) - 1\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fVQOjcPVLrUc"
      },
      "outputs": [],
      "source": [
        "@tf.function()\n",
        "def random_jitter(input_image, real_image):\n",
        "  # resizing to 286 x 286 x 3\n",
        "  input_image, real_image = resize(input_image, real_image, 286, 286)\n",
        "\n",
        "  # randomly cropping to 256 x 256 x 3\n",
        "  input_image, real_image = random_crop(input_image, real_image)\n",
        "\n",
        "  if tf.random.uniform(()) > 0.5:\n",
        "    # random mirroring\n",
        "    input_image = tf.image.flip_left_right(input_image)\n",
        "    real_image = tf.image.flip_left_right(real_image)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wfAQbzy799UV"
      },
      "source": [
        "As you can see in the images below, the inputs go through random jittering.\n",
        "Random jittering, as described in the paper, consists of the following steps:\n",
        "\n",
        "1. Resize the image to a larger height and width\n",
        "2. Randomly crop to the target size\n",
        "3. Randomly flip the image horizontally"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n0OGdi6D92kM"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(6, 6))\n",
        "for i in range(4):\n",
        "  rj_inp, rj_re = random_jitter(inp, re)\n",
        "  plt.subplot(2, 2, i+1)\n",
        "  plt.imshow(rj_inp/255.0)\n",
        "  plt.axis('off')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tyaP4hLJ8b4W"
      },
      "outputs": [],
      "source": [
        "def load_image_train(image_file):\n",
        "  input_image, real_image = load(image_file)\n",
        "  input_image, real_image = random_jitter(input_image, real_image)\n",
        "  input_image, real_image = normalize(input_image, real_image)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VB3Z6D_zKSru"
      },
      "outputs": [],
      "source": [
        "def load_image_test(image_file):\n",
        "  input_image, real_image = load(image_file)\n",
        "  input_image, real_image = resize(input_image, real_image,\n",
        "                                   IMG_HEIGHT, IMG_WIDTH)\n",
        "  input_image, real_image = normalize(input_image, real_image)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PIGN6ouoQxt3"
      },
      "source": [
        "## Input Pipeline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SQHmYSmk8b4b"
      },
      "outputs": [],
      "source": [
        "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n",
        "train_dataset = train_dataset.map(load_image_train,\n",
        "                                  num_parallel_calls=tf.data.AUTOTUNE)\n",
        "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n",
        "train_dataset = train_dataset.batch(BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MS9J0yA58b4g"
      },
      "outputs": [],
      "source": [
        "test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n",
        "test_dataset = test_dataset.map(load_image_test)\n",
        "test_dataset = test_dataset.batch(BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "THY-sZMiQ4UV"
      },
      "source": [
        "## Build the Generator\n",
        "  * The architecture of the generator is a modified U-Net.\n",
        "  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)\n",
        "  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)\n",
        "  * There are skip connections between the encoder and decoder (as in U-Net).\n"
      ]
    },
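    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the shape bookkeeping of the U-Net concrete, the following sketch traces the spatial size and channel count through the encoder and decoder using plain Python arithmetic. It assumes a 256 x 256 input, stride-2 blocks with `'same'` padding, and the filter counts used by the `Generator` defined later in this notebook. The `demo_`-prefixed names are illustrative only and are not used by the rest of the tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative shape trace for a U-Net-style generator (plain Python, no TensorFlow).\n",
        "# Assumption: every block uses stride 2 with 'same' padding, so each encoder\n",
        "# block halves the spatial size and each decoder block doubles it.\n",
        "demo_down_filters = [64, 128, 256, 512, 512, 512, 512, 512]\n",
        "demo_up_filters = [512, 512, 512, 512, 256, 128, 64]\n",
        "\n",
        "demo_size = 256\n",
        "demo_encoder_shapes = []\n",
        "for demo_filters in demo_down_filters:\n",
        "  demo_size = demo_size // 2\n",
        "  demo_encoder_shapes.append((demo_size, demo_size, demo_filters))\n",
        "\n",
        "# Skip connections use every encoder output except the bottleneck, in reverse order.\n",
        "demo_skips = list(reversed(demo_encoder_shapes[:-1]))\n",
        "\n",
        "demo_decoder_shapes = []\n",
        "for demo_filters, (skip_size, _, skip_channels) in zip(demo_up_filters, demo_skips):\n",
        "  demo_size = demo_size * 2\n",
        "  assert demo_size == skip_size  # spatial sizes must match to concatenate\n",
        "  # Concatenation stacks the upsampled features and the skip along the channel axis.\n",
        "  demo_decoder_shapes.append((demo_size, demo_size, demo_filters + skip_channels))\n",
        "\n",
        "print('encoder:', demo_encoder_shapes)\n",
        "print('decoder:', demo_decoder_shapes)"
      ]
    },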
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tqqvWxlw8b4l"
      },
      "outputs": [],
      "source": [
        "OUTPUT_CHANNELS = 3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3R09ATE_SH9P"
      },
      "outputs": [],
      "source": [
        "def downsample(filters, size, apply_batchnorm=True):\n",
        "  initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "  result = tf.keras.Sequential()\n",
        "  result.add(\n",
        "      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',\n",
        "                             kernel_initializer=initializer, use_bias=False))\n",
        "\n",
        "  if apply_batchnorm:\n",
        "    result.add(tf.keras.layers.BatchNormalization())\n",
        "\n",
        "  result.add(tf.keras.layers.LeakyReLU())\n",
        "\n",
        "  return result"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a6_uCZCppTh7"
      },
      "outputs": [],
      "source": [
        "down_model = downsample(3, 4)\n",
        "down_result = down_model(tf.expand_dims(inp, 0))\n",
        "print(down_result.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nhgDsHClSQzP"
      },
      "outputs": [],
      "source": [
        "def upsample(filters, size, apply_dropout=False):\n",
        "  initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "  result = tf.keras.Sequential()\n",
        "  result.add(\n",
        "    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,\n",
        "                                    padding='same',\n",
        "                                    kernel_initializer=initializer,\n",
        "                                    use_bias=False))\n",
        "\n",
        "  result.add(tf.keras.layers.BatchNormalization())\n",
        "\n",
        "  if apply_dropout:\n",
        "      result.add(tf.keras.layers.Dropout(0.5))\n",
        "\n",
        "  result.add(tf.keras.layers.ReLU())\n",
        "\n",
        "  return result"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mz-ahSdsq0Oc"
      },
      "outputs": [],
      "source": [
        "up_model = upsample(3, 4)\n",
        "up_result = up_model(down_result)\n",
        "print(up_result.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFPI4Nu-8b4q"
      },
      "outputs": [],
      "source": [
        "def Generator():\n",
        "  inputs = tf.keras.layers.Input(shape=[256,256,3])\n",
        "\n",
        "  down_stack = [\n",
        "    downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)\n",
        "    downsample(128, 4), # (bs, 64, 64, 128)\n",
        "    downsample(256, 4), # (bs, 32, 32, 256)\n",
        "    downsample(512, 4), # (bs, 16, 16, 512)\n",
        "    downsample(512, 4), # (bs, 8, 8, 512)\n",
        "    downsample(512, 4), # (bs, 4, 4, 512)\n",
        "    downsample(512, 4), # (bs, 2, 2, 512)\n",
        "    downsample(512, 4), # (bs, 1, 1, 512)\n",
        "  ]\n",
        "\n",
        "  up_stack = [\n",
        "    upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)\n",
        "    upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)\n",
        "    upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)\n",
        "    upsample(512, 4), # (bs, 16, 16, 1024)\n",
        "    upsample(256, 4), # (bs, 32, 32, 512)\n",
        "    upsample(128, 4), # (bs, 64, 64, 256)\n",
        "    upsample(64, 4), # (bs, 128, 128, 128)\n",
        "  ]\n",
        "\n",
        "  initializer = tf.random_normal_initializer(0., 0.02)\n",
        "  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,\n",
        "                                         strides=2,\n",
        "                                         padding='same',\n",
        "                                         kernel_initializer=initializer,\n",
        "                                         activation='tanh') # (bs, 256, 256, 3)\n",
        "\n",
        "  x = inputs\n",
        "\n",
        "  # Downsampling through the model\n",
        "  skips = []\n",
        "  for down in down_stack:\n",
        "    x = down(x)\n",
        "    skips.append(x)\n",
        "\n",
        "  skips = reversed(skips[:-1])\n",
        "\n",
        "  # Upsampling and establishing the skip connections\n",
        "  for up, skip in zip(up_stack, skips):\n",
        "    x = up(x)\n",
        "    x = tf.keras.layers.Concatenate()([x, skip])\n",
        "\n",
        "  x = last(x)\n",
        "\n",
        "  return tf.keras.Model(inputs=inputs, outputs=x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dIbRPFzjmV85"
      },
      "outputs": [],
      "source": [
        "generator = Generator()\n",
        "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U1N1_obwtdQH"
      },
      "outputs": [],
      "source": [
        "gen_output = generator(inp[tf.newaxis,...], training=False)\n",
        "plt.imshow(gen_output[0,...])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dpDPEQXIAiQO"
      },
      "source": [
        "* **Generator loss**\n",
        "  * It is a sigmoid cross-entropy loss between the discriminator's output for the generated images and an **array of ones**.\n",
        "  * The [paper](https://arxiv.org/abs/1611.07004) also includes an L1 loss, which is the MAE (mean absolute error) between the generated image and the target image.\n",
        "  * This allows the generated image to become structurally similar to the target image.\n",
        "  * The total generator loss is `gan_loss + LAMBDA * l1_loss`, where `LAMBDA = 100`. This value was chosen by the authors of the [paper](https://arxiv.org/abs/1611.07004)."
      ]
    },
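    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small worked example of the formula above, the sketch below computes the three terms by hand for made-up numbers, using only the Python standard library. The helper `demo_sigmoid_cross_entropy` and all values are hypothetical and chosen purely for illustration; the tutorial itself computes the adversarial term with `tf.keras.losses.BinaryCrossentropy(from_logits=True)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def demo_sigmoid_cross_entropy(labels, logits):\n",
        "  # Mean binary cross-entropy computed from raw logits (before the sigmoid).\n",
        "  # Uses the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)).\n",
        "  losses = [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))\n",
        "            for z, x in zip(labels, logits)]\n",
        "  return sum(losses) / len(losses)\n",
        "\n",
        "# Made-up discriminator logits for a generated image (one value per patch).\n",
        "demo_disc_generated_output = [0.0, 0.0, 0.0, 0.0]\n",
        "# The generator wants these patches classified as real, so the labels are all ones.\n",
        "demo_gan_loss = demo_sigmoid_cross_entropy([1.0] * 4, demo_disc_generated_output)\n",
        "\n",
        "# Made-up pixel values for a target image and a generated image.\n",
        "demo_target = [0.5, -0.5, 1.0, 0.0]\n",
        "demo_gen_output = [0.25, -0.25, 0.5, 0.0]\n",
        "demo_l1_loss = sum(abs(t - g) for t, g in zip(demo_target, demo_gen_output)) / len(demo_target)\n",
        "\n",
        "demo_lambda = 100\n",
        "demo_total_gen_loss = demo_gan_loss + demo_lambda * demo_l1_loss\n",
        "print(demo_gan_loss, demo_l1_loss, demo_total_gen_loss)"
      ]
    },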
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fSZbDgESHIV6"
      },
      "source": [
        "The training procedure for the generator is shown below:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cyhxTuvJyIHV"
      },
      "outputs": [],
      "source": [
        "LAMBDA = 100"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "90BIcCKcDMxz"
      },
      "outputs": [],
      "source": [
        "def generator_loss(disc_generated_output, gen_output, target):\n",
        "  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)\n",
        "\n",
        "  # mean absolute error\n",
        "  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
        "\n",
        "  total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n",
        "\n",
        "  return total_gen_loss, gan_loss, l1_loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TlB-XMY5Awj9"
      },
      "source": [
        "![Generator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gen.png?raw=1)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZTKZfoaoEF22"
      },
      "source": [
        "## Build the Discriminator\n",
        "  * The Discriminator is a PatchGAN.\n",
        "  * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU)\n",
        "  * The shape of the output after the last layer is (batch_size, 30, 30, 1)\n",
        "  * Each element of the 30x30 output classifies a 70x70 patch of the input image (such an architecture is called a PatchGAN).\n",
        "  * The discriminator receives 2 inputs:\n",
        "    * Input image and the target image, which it should classify as real.\n",
        "    * Input image and the generated image (output of generator), which it should classify as fake.\n",
        "    * In the code, the two images of each pair are concatenated along the channel axis (`tf.keras.layers.concatenate([inp, tar])`)."
      ]
    },
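    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The 30 x 30 output size and the 70 x 70 patch size quoted above can be checked with a little arithmetic. The sketch below is plain Python and assumes the layer configuration of the `Discriminator` defined in the next cell: three stride-2 convolutions with `'same'` padding, followed by two stride-1 convolutions with kernel size 4, each preceded by one pixel of zero padding per side. The `demo_`-prefixed names are illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# (kernel_size, stride) for each convolution in the discriminator, in order.\n",
        "demo_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]\n",
        "\n",
        "# Receptive field: each layer adds (kernel - 1) * jump input pixels, where\n",
        "# jump is the product of the strides of all earlier layers.\n",
        "demo_receptive_field = 1\n",
        "demo_jump = 1\n",
        "for demo_kernel, demo_stride in demo_layers:\n",
        "  demo_receptive_field += (demo_kernel - 1) * demo_jump\n",
        "  demo_jump *= demo_stride\n",
        "\n",
        "# Output size: stride-2 'same' convolutions halve the size; the stride-1\n",
        "# convolutions see one pixel of zero padding per side and use 'valid' padding.\n",
        "demo_out = 256\n",
        "for demo_kernel, demo_stride in demo_layers:\n",
        "  if demo_stride == 2:\n",
        "    demo_out = demo_out // 2\n",
        "  else:\n",
        "    demo_out = (demo_out + 2) - demo_kernel + 1\n",
        "\n",
        "print('receptive field:', demo_receptive_field)\n",
        "print('output size:', demo_out)"
      ]
    },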
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ll6aNeQx8b4v"
      },
      "outputs": [],
      "source": [
        "def Discriminator():\n",
        "  initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')\n",
        "  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')\n",
        "\n",
        "  x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)\n",
        "\n",
        "  down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)\n",
        "  down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)\n",
        "  down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)\n",
        "\n",
        "  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)\n",
        "  conv = tf.keras.layers.Conv2D(512, 4, strides=1,\n",
        "                                kernel_initializer=initializer,\n",
        "                                use_bias=False)(zero_pad1) # (bs, 31, 31, 512)\n",
        "\n",
        "  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)\n",
        "\n",
        "  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)\n",
        "\n",
        "  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)\n",
        "\n",
        "  last = tf.keras.layers.Conv2D(1, 4, strides=1,\n",
        "                                kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)\n",
        "\n",
        "  return tf.keras.Model(inputs=[inp, tar], outputs=last)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YHoUui4om-Ev"
      },
      "outputs": [],
      "source": [
        "discriminator = Discriminator()\n",
        "tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gDkA05NE6QMs"
      },
      "outputs": [],
      "source": [
        "disc_out = discriminator([inp[tf.newaxis,...], gen_output], training=False)\n",
        "plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')\n",
        "plt.colorbar()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AOqg1dhUAWoD"
      },
      "source": [
        "**Discriminator loss**\n",
        "  * The discriminator loss function takes 2 inputs: the discriminator's outputs for the **real images** and for the **generated images**.\n",
        "  * `real_loss` is a sigmoid cross-entropy loss of the **real images** and an **array of ones (since these are the real images)**.\n",
        "  * `generated_loss` is a sigmoid cross-entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**.\n",
        "  * `total_disc_loss` is then the sum of `real_loss` and `generated_loss`.\n"
      ]
    },
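    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The discriminator loss described above can also be worked through by hand. The sketch below uses only the Python standard library and made-up logits; `demo_bce_from_logits` is a hypothetical helper written for this illustration, standing in for the `tf.keras.losses.BinaryCrossentropy(from_logits=True)` object that the tutorial uses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def demo_bce_from_logits(labels, logits):\n",
        "  # Mean binary cross-entropy computed from raw logits (before the sigmoid),\n",
        "  # in the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)).\n",
        "  losses = [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))\n",
        "            for z, x in zip(labels, logits)]\n",
        "  return sum(losses) / len(losses)\n",
        "\n",
        "# Made-up logits: positive means the discriminator leans towards real.\n",
        "demo_disc_real_output = [2.0, 2.0]\n",
        "demo_disc_fake_output = [-2.0, -2.0]\n",
        "\n",
        "# Real pairs are compared against ones, generated pairs against zeros.\n",
        "demo_real_loss = demo_bce_from_logits([1.0, 1.0], demo_disc_real_output)\n",
        "demo_generated_loss = demo_bce_from_logits([0.0, 0.0], demo_disc_fake_output)\n",
        "demo_total_disc_loss = demo_real_loss + demo_generated_loss\n",
        "print(demo_real_loss, demo_generated_loss, demo_total_disc_loss)"
      ]
    },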
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q1Xbz5OaLj5C"
      },
      "outputs": [],
      "source": [
        "loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wkMNfBWlT-PV"
      },
      "outputs": [],
      "source": [
        "def discriminator_loss(disc_real_output, disc_generated_output):\n",
        "  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)\n",
        "\n",
        "  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)\n",
        "\n",
        "  total_disc_loss = real_loss + generated_loss\n",
        "\n",
        "  return total_disc_loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ede4p2YELFa"
      },
      "source": [
        "The training procedure for the discriminator is shown below.\n",
        "\n",
        "To learn more about the architecture and the hyperparameters, you can refer to the [paper](https://arxiv.org/abs/1611.07004)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IS9sHa-1BoAF"
      },
      "source": [
        "![Discriminator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/dis.png?raw=1)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0FMYgY_mPfTi"
      },
      "source": [
        "## Define the Optimizers and Checkpoint-saver\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lbHFNexF0x6O"
      },
      "outputs": [],
      "source": [
        "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
        "discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WJnftd5sQsv6"
      },
      "outputs": [],
      "source": [
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
        "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
        "                                 discriminator_optimizer=discriminator_optimizer,\n",
        "                                 generator=generator,\n",
        "                                 discriminator=discriminator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rw1fkAczTQYh"
      },
      "source": [
        "## Generate Images\n",
        "\n",
        "Write a function to plot some images during training.\n",
        "\n",
        "* We pass images from the test dataset to the generator.\n",
        "* The generator then translates the input image into the output.\n",
        "* The last step is to plot the predictions and **voila!**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rb0QQFHF-JfS"
      },
      "source": [
        "Note: `training=True` is intentional here, since\n",
        "we want the batch statistics while running the model\n",
        "on the test dataset. If we use `training=False`, we get\n",
        "the accumulated statistics learned from the training dataset\n",
        "(which we don't want)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RmdVsmvhPxyy"
      },
      "outputs": [],
      "source": [
        "def generate_images(model, test_input, tar):\n",
        "  prediction = model(test_input, training=True)\n",
        "  plt.figure(figsize=(15,15))\n",
        "\n",
        "  display_list = [test_input[0], tar[0], prediction[0]]\n",
        "  title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
        "\n",
        "  for i in range(3):\n",
        "    plt.subplot(1, 3, i+1)\n",
        "    plt.title(title[i])\n",
        "    # getting the pixel values between [0, 1] to plot it.\n",
        "    plt.imshow(display_list[i] * 0.5 + 0.5)\n",
        "    plt.axis('off')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Fc4NzT-DgEx"
      },
      "outputs": [],
      "source": [
        "for example_input, example_target in test_dataset.take(1):\n",
        "  generate_images(generator, example_input, example_target)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NLKOG55MErD0"
      },
      "source": [
        "## Training\n",
        "\n",
        "* For each example input, generate an output.\n",
        "* The discriminator receives the `input_image` and the generated image as the first input. The second input is the `input_image` and the `target` image.\n",
        "* Next, we calculate the generator and the discriminator loss.\n",
        "* Then, we calculate the gradients of each loss with respect to the generator and the discriminator trainable variables, respectively, and apply them with the corresponding optimizers.\n",
        "* Then log the losses to TensorBoard."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NS2GWywBbAWo"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 150"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xNNMDBNH12q-"
      },
      "outputs": [],
      "source": [
        "import datetime\n",
        "log_dir=\"logs/\"\n",
        "\n",
        "summary_writer = tf.summary.create_file_writer(\n",
        "  log_dir + \"fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KBKUV2sKXDbY"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(input_image, target, epoch):\n",
        "  # Record the forward pass on two tapes, one per network.\n",
        "  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
        "    gen_output = generator(input_image, training=True)\n",
        "\n",
        "    # Score the real pair and the generated pair with the discriminator.\n",
        "    disc_real_output = discriminator([input_image, target], training=True)\n",
        "    disc_generated_output = discriminator([input_image, gen_output], training=True)\n",
        "\n",
        "    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)\n",
        "    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n",
        "\n",
        "  # Compute the gradients of each loss with respect to its own network.\n",
        "  generator_gradients = gen_tape.gradient(gen_total_loss,\n",
        "                                          generator.trainable_variables)\n",
        "  discriminator_gradients = disc_tape.gradient(disc_loss,\n",
        "                                               discriminator.trainable_variables)\n",
        "\n",
        "  # Apply one optimizer update to each network.\n",
        "  generator_optimizer.apply_gradients(zip(generator_gradients,\n",
        "                                          generator.trainable_variables))\n",
        "  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,\n",
        "                                              discriminator.trainable_variables))\n",
        "\n",
        "  # Log the losses to TensorBoard, using the epoch as the step.\n",
        "  with summary_writer.as_default():\n",
        "    tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)\n",
        "    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)\n",
        "    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)\n",
        "    tf.summary.scalar('disc_loss', disc_loss, step=epoch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hx7s-vBHFKdh"
      },
      "source": [
        "The actual training loop:\n",
        "\n",
        "* Iterates over the number of epochs.\n",
        "* On each epoch, it clears the display and runs `generate_images` to show its progress.\n",
        "* On each epoch, it iterates over the training dataset, printing a '.' for each training step and starting a new line every 100 steps.\n",
        "* It saves a checkpoint every 20 epochs, and once more after the final epoch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2M7LmLtGEMQJ"
      },
      "outputs": [],
      "source": [
        "def fit(train_ds, epochs, test_ds):\n",
        "  for epoch in range(epochs):\n",
        "    start = time.time()\n",
        "\n",
        "    display.clear_output(wait=True)\n",
        "\n",
        "    # Show the generator's current output on one test example.\n",
        "    for example_input, example_target in test_ds.take(1):\n",
        "      generate_images(generator, example_input, example_target)\n",
        "    print(\"Epoch: \", epoch)\n",
        "\n",
        "    # Pass the epoch as a tensor: a changing Python int would make\n",
        "    # `tf.function` retrace `train_step` on every epoch.\n",
        "    epoch_step = tf.constant(epoch, dtype=tf.int64)\n",
        "\n",
        "    # Train\n",
        "    for n, (input_image, target) in train_ds.enumerate():\n",
        "      print('.', end='')\n",
        "      if (n+1) % 100 == 0:\n",
        "        print()\n",
        "      train_step(input_image, target, epoch_step)\n",
        "    print()\n",
        "\n",
        "    # Save a checkpoint every 20 epochs.\n",
        "    if (epoch + 1) % 20 == 0:\n",
        "      checkpoint.save(file_prefix=checkpoint_prefix)\n",
        "\n",
        "    print('Time taken for epoch {} is {} sec\\n'.format(\n",
        "        epoch + 1, time.time() - start))\n",
        "  # Save a final checkpoint after the last epoch.\n",
        "  checkpoint.save(file_prefix=checkpoint_prefix)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wozqyTh2wmCu"
      },
      "source": [
        "This training loop saves logs that you can view in TensorBoard to monitor the training progress. If you work on a local machine, you would launch a separate TensorBoard process. In a notebook, it is easiest to launch the viewer before starting the training.\n",
        "\n",
        "To launch the viewer, paste the following into a code cell:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ot22ujrlLhOd"
      },
      "outputs": [],
      "source": [
        "#docs_infra: no_execute\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir {log_dir}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pe0-8Bzg22ox"
      },
      "source": [
        "Now run the training loop:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a1zZmKmvOH85"
      },
      "outputs": [],
      "source": [
        "fit(train_dataset, EPOCHS, test_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oeq9sByu86-B"
      },
      "source": [
        "If you want to share the TensorBoard results _publicly_, you can upload the logs to [TensorBoard.dev](https://tensorboard.dev/) by copying the following into a code cell.\n",
        "\n",
        "Note: This requires a Google account.\n",
        "\n",
        "```\n",
        "!tensorboard dev upload --logdir {log_dir}\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l-kT7WHRKz-E"
      },
      "source": [
        "Caution: This command does not terminate. It's designed to continuously upload the results of long-running experiments. Once your data is uploaded, stop it using the \"interrupt execution\" option in your notebook tool."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-lGhS_LfwQoL"
      },
      "source": [
        "You can view the [results of a previous run](https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw) of this notebook on [TensorBoard.dev](https://tensorboard.dev/).\n",
        "\n",
        "TensorBoard.dev is a managed experience for hosting, tracking, and sharing ML experiments with everyone.\n",
        "\n",
        "It can also be included inline using an `<iframe>`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8IS4c93guQ8E"
      },
      "outputs": [],
      "source": [
        "display.IFrame(\n",
        "    src=\"https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw\",\n",
        "    width=\"100%\",\n",
        "    height=\"1000px\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DMTm4peo3cem"
      },
      "source": [
        "Interpreting the logs from a GAN is more subtle than interpreting those from a simple classification or regression model. Things to look for:\n",
        "\n",
        "* Check that neither model has \"won\". If either the `gen_gan_loss` or the `disc_loss` gets very low, it's an indicator that this model is dominating the other, and you are not successfully training the combined model.\n",
        "* The value `log(2) = 0.69` is a good reference point for these losses, as it indicates a perplexity of 2: the discriminator is, on average, equally uncertain about the two options.\n",
        "* For the `disc_loss`, a value below `0.69` means the discriminator is doing better than random on the combined set of real and generated images.\n",
        "* For the `gen_gan_loss`, a value below `0.69` means the generator is doing better than random at fooling the discriminator.\n",
        "* As training progresses the `gen_l1_loss` should go down."
      ]
    },
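    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough check on the `log(2)` reference point above, the sketch below assumes that each GAN loss term is a binary cross-entropy using the natural logarithm (an assumption here, since the loss functions are defined earlier in the notebook), and writes $D$ for the discriminator's sigmoid output on a patch. If the discriminator is maximally uncertain, so that $D = 0.5$ everywhere, then:\n",
        "\n",
        "$$\n",
        "-\\log D = -\\log(1 - D) = -\\log 0.5 = \\log 2 \\approx 0.693, \\qquad \\text{perplexity} = e^{\\log 2} = 2\n",
        "$$\n",
        "\n",
        "Under the same assumption, a `disc_loss` that adds a real term and a generated term (rather than averaging them) would sit at $2 \\log 2 \\approx 1.386$ at chance level, so it is worth checking how `discriminator_loss` combines its terms before comparing it against `0.69`."
      ]
    },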
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kz80bY3aQ1VZ"
      },
      "source": [
        "## Restore the latest checkpoint and test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HSSm4kfvJiqv"
      },
      "outputs": [],
      "source": [
        "!ls {checkpoint_dir}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4t4x69adQ5xb"
      },
      "outputs": [],
      "source": [
        "# Restore the latest checkpoint in checkpoint_dir.\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1RGysMU_BZhx"
      },
      "source": [
        "## Generate using test dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KUgSnmy2nqSP"
      },
      "outputs": [],
      "source": [
        "# Run the trained model on a few examples from the test dataset\n",
        "for inp, tar in test_dataset.take(5):\n",
        "  generate_images(generator, inp, tar)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "pix2pix.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
